package cn.schoolwow.dns.resolver;

import cn.schoolwow.dns.constants.DictionaryKey;
import cn.schoolwow.dns.domain.DNSRequest;
import cn.schoolwow.dns.domain.DNSResponse;
import cn.schoolwow.dns.domain.header.Header;
import cn.schoolwow.dns.domain.header.constant.AA;
import cn.schoolwow.dns.domain.header.constant.RA;
import cn.schoolwow.dns.domain.header.constant.RCODE;
import cn.schoolwow.dns.domain.question.Question;
import cn.schoolwow.dns.domain.question.constants.QTYPE;
import cn.schoolwow.dns.domain.rr.ResourceRecord;
import cn.schoolwow.dns.entity.*;
import cn.schoolwow.quickdao.dao.DAO;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import javax.sql.rowset.serial.SerialBlob;
import java.io.IOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetSocketAddress;
import java.sql.Blob;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * DNS request handler.
 *
 * <p>After Spring has injected its dependencies, this bean starts a background thread that
 * listens for DNS queries on UDP port 53 and hands each datagram to
 * {@code dnsResolverThreadPool}. A query with exactly one question is answered from the local
 * {@link DNSRecord} table when the domain matches. Otherwise it is forwarded to the first enabled
 * {@link DNSRecursionServer} (ordered by {@code order}). Every single-question query is stored as
 * a {@link DNSHandleRecord}. Datagrams that fail processing are stored as
 * {@link DNSErrorDatagram}s. Queries with any other number of questions are refused.
 */
@Component
public class DNSResolver implements InitializingBean {
    private static final Logger logger = LoggerFactory.getLogger(DNSResolver.class);

    /** Standard DNS port, used both for listening and for contacting upstream servers. */
    private static final int DNS_PORT = 53;

    /** Length of the fixed DNS header; anything shorter cannot be a DNS message. */
    private static final int DNS_HEADER_LENGTH = 12;

    /**
     * UDP receive buffer size. Plain DNS over UDP is capped at 512 bytes, but EDNS(0) lets peers
     * exchange larger datagrams, which a 512-byte buffer would silently truncate.
     */
    private static final int UDP_BUFFER_SIZE = 4096;

    /** Maximum time to wait for an upstream reply before {@code receive} times out. */
    private static final int FORWARD_TIMEOUT_MILLIS = 5000;

    @Resource
    private ThreadPoolExecutor dnsResolverThreadPool;

    @Resource
    private DAO dao;

    /** Starts the background thread that listens for DNS queries on UDP port 53. */
    @Override
    public void afterPropertiesSet() throws Exception {
        logger.info("[监听udp协议53端口]");
        new Thread(this::listen, "dns-udp-listener").start();
    }

    /**
     * Receive loop: reads datagrams from port 53 and dispatches each one to the worker pool.
     * It returns only when the socket cannot be opened or has been closed.
     */
    private void listen() {
        try (DatagramSocket requestDatagramSocket = new DatagramSocket(DNS_PORT)) {
            while (!requestDatagramSocket.isClosed()) {
                byte[] data = new byte[UDP_BUFFER_SIZE];
                DatagramPacket packet = new DatagramPacket(data, data.length);
                try {
                    requestDatagramSocket.receive(packet);
                } catch (IOException e) {
                    if (requestDatagramSocket.isClosed()) {
                        break;
                    }
                    // One failed receive (e.g. an OS-reported ICMP error) must not stop the listener.
                    logger.warn("[接收报文失败]", e);
                    continue;
                }
                // Ignore datagrams shorter than the 12-byte DNS header.
                if (packet.getLength() < DNS_HEADER_LENGTH) {
                    logger.warn("[忽略报文]该报文长度小于12字节!当前报文长度:" + packet.getLength());
                    continue;
                }
                byte[] trimData = Arrays.copyOf(packet.getData(), packet.getLength());
                InetSocketAddress client = new InetSocketAddress(packet.getAddress(), packet.getPort());
                try {
                    // Process on a worker thread so the receive loop is never blocked.
                    dnsResolverThreadPool.execute(() -> handlePacket(requestDatagramSocket, client, trimData));
                } catch (RejectedExecutionException e) {
                    // A saturated pool drops this query instead of killing the listener thread.
                    logger.warn("[丢弃报文]线程池拒绝执行,来源:{}", client, e);
                }
            }
        } catch (IOException e) {
            logger.error("[监听udp协议53端口失败]", e);
        }
    }

    /**
     * Resolves one received query and sends the reply back to the client.
     *
     * <p>The reply goes out through the listening socket, so it originates from port 53. Stub
     * resolvers commonly discard replies whose source port differs from the one they queried.
     * Any failure is logged and the raw datagram is stored as a {@link DNSErrorDatagram}.
     *
     * @param serverSocket the port-53 socket the query was received on
     * @param client address and port the query came from
     * @param trimData the received datagram, trimmed to its actual length
     */
    private void handlePacket(DatagramSocket serverSocket, InetSocketAddress client, byte[] trimData) {
        try {
            DNSResponse dnsResponse = handleDNSRequest(client.getPort(), trimData);
            byte[] responseData = dnsResponse.toByteArray();
            serverSocket.send(new DatagramPacket(responseData, responseData.length, client));
        } catch (Exception e) {
            logger.error("[处理报文失败]来源:{}", client, e);
            saveErrorDatagram(trimData, e);
        }
    }

    /**
     * Stores a datagram that could not be processed as a {@link DNSErrorDatagram}.
     * Best-effort: SQL and runtime failures while saving are logged, not propagated.
     *
     * @param datagram raw datagram bytes
     * @param cause the failure; its message is stored with the datagram
     */
    private void saveErrorDatagram(byte[] datagram, Exception cause) {
        try {
            DNSErrorDatagram dnsErrorDatagram = new DNSErrorDatagram();
            Blob blob = new SerialBlob(datagram);
            dnsErrorDatagram.setBlob(blob);
            dnsErrorDatagram.setNormal(false);
            dnsErrorDatagram.setMessage(cause.getMessage());
            dao.insert(dnsErrorDatagram);
        } catch (SQLException | RuntimeException ex) {
            logger.error("[保存错误报文失败]", ex);
        }
    }

    /**
     * Stores the handling record. A database failure is logged so it cannot prevent the
     * reply from being sent.
     */
    private void saveHandleRecord(DNSHandleRecord dnsHandleRecord) {
        try {
            dao.insert(dnsHandleRecord);
        } catch (RuntimeException e) {
            logger.error("[保存处理记录失败]{}", dnsHandleRecord, e);
        }
    }

    /**
     * Handles one DNS query.
     *
     * <p>A query with exactly one question is answered by {@link #handleSingleQuestion} and
     * stored as a {@link DNSHandleRecord}. An {@link IOException} during resolution becomes a
     * SERVER_FAILURE response. Queries with any other number of questions are refused.
     *
     * @param port the client's source port, stored in the handle record
     * @param trimData the raw query bytes
     * @return the response to send back to the client
     * @throws IOException if building the {@link DNSRequest} from {@code trimData} fails
     */
    private DNSResponse handleDNSRequest(int port, byte[] trimData) throws IOException {
        DNSRequest dnsRequest = new DNSRequest(trimData);
        logger.debug("[接收DNS查询请求]{}", dnsRequest);
        if (dnsRequest.getQuestions().length != 1) {
            return handleMultiQuestion(dnsRequest);
        }
        Question question = dnsRequest.getQuestions()[0];
        DNSHandleRecord dnsHandleRecord = new DNSHandleRecord();
        dnsHandleRecord.setPort(port);
        dnsHandleRecord.setTransactionId(dnsRequest.getHeader().getID());
        dnsHandleRecord.setDomain(question.getQNAME());
        if (null != question.getQTYPE()) {
            dnsHandleRecord.setType(question.getQTYPE().name());
        }
        DNSResponse dnsResponse;
        try {
            dnsResponse = handleSingleQuestion(dnsRequest, question);
            Header header = dnsResponse.getHeader();
            dnsHandleRecord.setRcode(header.getRCODE().name());
            dnsHandleRecord.setAa(header.getAA().value == 1);
            dnsHandleRecord.setRa(header.getRA().value == 1);
            if (dnsResponse.getAnswers().size() > 0) {
                ResourceRecord answer = dnsResponse.getAnswers().get(0);
                dnsHandleRecord.setValue(answer.getRDATAFormat());
                dnsHandleRecord.setTtl(answer.getTTL());
            } else {
                // A response without answers is recorded as SERVER_FAILURE, whatever RCODE it carries.
                dnsHandleRecord.setRcode(RCODE.SERVER_FAILURE.name());
            }
        } catch (IOException e) {
            dnsResponse = dnsRequest.getDNSResponse();
            dnsResponse.rcode(RCODE.SERVER_FAILURE);
            dnsHandleRecord.setRcode(RCODE.SERVER_FAILURE.name());
            dnsHandleRecord.setReason(e.getClass().getName() + ": " + e.getMessage());
        } finally {
            // Record every single-question query, whether it succeeded or not.
            saveHandleRecord(dnsHandleRecord);
        }
        return dnsResponse;
    }

    /**
     * Answers a query that has exactly one question.
     *
     * <p>If the local database has a {@link DNSRecord} for the domain, an authoritative answer is
     * built from it. Otherwise the query is forwarded to the first enabled recursion server.
     *
     * @throws IOException if no recursion server is enabled or forwarding fails
     */
    private DNSResponse handleSingleQuestion(DNSRequest dnsRequest, Question question) throws IOException {
        DNSResponse dnsResponse = dnsRequest.getDNSResponse();
        // Check the local database for a matching record.
        // NOTE(review): records are matched by domain only and always answered with A-type RDATA,
        // whatever QTYPE was asked. Confirm this is intended for e.g. AAAA/MX queries.
        DNSRecord dnsRecord = dao.fetch(DNSRecord.class, "domain", question.getQNAME());
        if (null != dnsRecord) {
            ResourceRecord answer = question.answer();
            answer.setTTL(dnsRecord.getTtl());
            answer.setRDATA(QTYPE.A, dnsRecord.getValue());
            dnsResponse.answer(answer);
            dnsResponse.aa(AA.AUTHORITATIVE_ANSWER);
            return dnsResponse;
        }
        // Pick the first enabled recursion server by its configured order.
        DNSRecursionServer recursionQueryRecord = dao.query(DNSRecursionServer.class)
                .addQuery("enable", true)
                .orderBy("order")
                .limit(0, 1)
                .execute()
                .getOne();
        if (null == recursionQueryRecord) {
            dnsResponse.ra(RA.CAN_NOT_RECURSIVE)
                    .rcode(RCODE.SERVER_FAILURE);
            throw new IOException("本地数据库无对应DNS记录且未启用任何DNS递归服务器!");
        }
        return forwardRequest(dnsRequest, recursionQueryRecord.getDnsServerIP());
    }

    /**
     * Forwards a query to an upstream DNS server over UDP and parses its reply.
     *
     * <p>The socket is always closed. If no reply arrives within the forwarding timeout, a
     * {@link java.net.SocketTimeoutException} is thrown. A reply that cannot be parsed is stored
     * as a {@link DNSErrorDatagram} before the parse failure is rethrown.
     *
     * @param dnsRequest the query to forward
     * @param ip upstream DNS server IP address
     * @return the parsed upstream response
     * @throws IOException on network failure or timeout
     */
    public DNSResponse forwardRequest(DNSRequest dnsRequest, String ip) throws IOException {
        byte[] trimData;
        try (DatagramSocket socket = new DatagramSocket()) {
            // Without a timeout, a silent upstream would block this worker thread forever.
            socket.setSoTimeout(FORWARD_TIMEOUT_MILLIS);
            byte[] requestData = dnsRequest.toByteArray();
            socket.connect(new InetSocketAddress(ip, DNS_PORT));
            socket.send(new DatagramPacket(requestData, requestData.length));
            byte[] buffer = new byte[UDP_BUFFER_SIZE];
            DatagramPacket responsePacket = new DatagramPacket(buffer, buffer.length);
            socket.receive(responsePacket);
            trimData = Arrays.copyOf(buffer, responsePacket.getLength());
        }
        try {
            return new DNSResponse(trimData);
        } catch (Exception e) {
            // Keep the unparseable reply for later inspection, then propagate the original failure.
            saveErrorDatagram(trimData, e);
            throw e;
        }
    }

    /** Handles a query whose question count is not exactly one by refusing it. */
    private DNSResponse handleMultiQuestion(DNSRequest dnsRequest) {
        DNSResponse dnsResponse = dnsRequest.getDNSResponse();
        return dnsResponse.rcode(RCODE.QUERY_REFUSED);
    }
}
